#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mar 5 2023

@author: Mobin Asri

- This script is for fitting a statistical model to coverage counts data, which is 
- inspired from Jordan Eizenga's script. The original model contained a mixture of 
- Gaussian distributions and an exponential component (for modeling errors). However 
- the model in this script contains a mixture of Negative Binomial distributions 
- with no exponential component
"""



from copy import deepcopy
from math import log, sqrt
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import nbinom
from scipy.special import digamma
from math import log
from collections import defaultdict
import argparse
import os
from matplotlib.lines import Line2D




class CoverageDistribution:

    """
        Attributes: 
            - cov_err:		The mean of the first component (the error component whose parameters are not tied to other components)
            - var_err:		The variance of the first component (the error component whose parameters are not tied to other components)
            - cov_per_ploidy:	The mean of the dominant component that usually represents the correctly assembled blocks
            - var_per_ploidy:	The variance of the dominant component that usually represents the correctly assembled blocks
            - n_comps:		Total number of components in the model (each collapsed subcomponent is counted as one)
            - mixture_weights:	A list of numbers summing to one showing the weights of components (length of the list is equal to n_comps)
            - ploidies: 	A list of ploidies showing how the parameters of the components (other than erroneous component) are tied together
                     		For example if false duplication is modeled the ploidies can be [0.5, 1, 2, 3, 4, ..., N]
                                if false duplication is not modeled the ploidies can be  [1, 2, 3, 4, ..., N] 
                                N is for the last collapsed component determined by the largest coverage present in the input
                  		(length of the list is equal to (n_comps - 1))
        Notes:
            - Since the means and variances of all components (except error component) are tied together it is enough to keep the parameters
              of the haploid component. These parameters can be saved as (r, theta) which are more common for Negative Binomial distributions
              or it can be saved as (mean, var). It is chosen to keep them as (mean, var) to emphasize on how the parameters are tied however during the 
              EM algorithm they are converted to (r,theta) and then converted back to (mean, var) after each round of update. (Look at the internal method update_params_em())
            - The model can be fit to the coverage counts data using the method CoverageDistribution.fit()      
    """
    def __init__(self, cov_err, var_err, cov_per_ploidy, var_per_ploidy, n_comps, mixture_weights, ploidies, pseudo_cov = 1):
        self.cov_err = cov_err
        self.var_err = var_err
        self.cov_per_ploidy = cov_per_ploidy
        self.var_per_ploidy = var_per_ploidy
        self.n_comps = n_comps
        self.mixture_weights = mixture_weights
        self.ploidies = ploidies
        self.pseudo_cov = pseudo_cov

    """
        Print the current values for the parameters of the model
    """
    def __repr__(self):
         w_str = ""
         for i in range(min(self.n_comps, 5)):
             w_str += f"{self.mixture_weights[i]:.2e},"
         out = "\n" + "-" * 20 + "\n"
         out += f"cov_err\tvar_err\tcov_per_ploidy\tvar_per_ploidy\tn_comps\tmixture_weights_first_5\n"
         out += f"{self.cov_err:.2e}\t{self.var_err:.2e}\t{self.cov_per_ploidy:.2e}\t{self.var_per_ploidy:.2e}\t{self.n_comps}\t{w_str}\n"
         return out
   
    """
        Arguments:
            - cov: a coverage value
        Returns:
            - A list that contains the probability of cov being generated by each component
    """
    def component_probabilites(self, cov):
        theta, r = self.get_theta_and_r_values()
        w = self.get_weights()
        K = self.n_comps

        ti = np.zeros(K)
        for k in range(K):
            ti[k] = w[k] * nbinom.pmf(cov, r[k], theta[k])
        ti /= sum(ti)
        
        return ti


    """
        Arguments:
            - cov: a coverage value
        Returns:
            - The probability of cov being generated by the "erroneous" component
    """
    def probability_erroneous(self, cov):
        probs = self.component_probabilites(cov)
        return probs[0]
     
      
    """
        Arguments:
            - cov: a coverage value
        Returns:
            - The probability of cov being generated by the "falsely duplicated" component
    """
    def probability_duplicated(self, cov):
        assert(self.contain_duplication())
        probs = self.component_probabilites(cov)
        return probs[1]

    """
        Arguments:
            - cov: a coverage value
        Returns:
            - The probability of cov being generated by the "haploid" component
    """
    def probability_haploid(self, cov):
        probs = self.component_probabilites(cov)
        if self.contain_duplication():
            return probs[2]
        else:
            return probs[1]

    
    """
        Arguments:
            - cov: a coverage value
        Returns:
            - The probability of cov being generated by the "collapsed" components
    """
    def probability_collapsed(self, cov):
        probs = self.component_probabilites(cov)
        if self.contain_duplication():
            return np.sum(probs[3:])
        else:
            return np.sum(probs[2:])

    """
        Returns:
            - True if the model has a component for false duplication (with ploidy = 0.5) otherwise False
    """
    def contain_duplication(self):
        return self.ploidies[0] == 0.5


    """ 
        This is the core method for performing the EM algorithm
        It is based on these two papers
        https://iopscience.iop.org/article/10.1088/1742-6596/1324/1/012093/meta
        https://onlinelibrary.wiley.com/doi/pdf/10.1111/1467-842X.00075
        
        Arguments:
            - cov_counts: an array of coverage counts
        Returns:
            - A new CoverageDistribution object with the updated parameters
    """
    def update_params_em(self, cov_counts_):

        cov_counts = deepcopy(cov_counts_)
        cov_counts += self.pseudo_cov

        theta, r = self.get_theta_and_r_values()
        w = self.get_weights()
        ploidies = self.ploidies
        K = self.n_comps


        ti = np.zeros((K, len(cov_counts)))
        delta = np.zeros((K, len(cov_counts)))


        for k in range(K):
            #print(k, r[k], theta[k])
            ti[k] = w[k] * nbinom.pmf(np.arange(len(cov_counts)), r[k], theta[k])
            delta[k] = r[k] * (digamma(r[k] + np.arange(len(cov_counts))) - digamma(r[k]))
        ti /= sum(ti[:,:])

        w_next = np.zeros(K)
        beta = np.zeros(K)
        for k in range(K):
            beta[k] = -1 * theta[k] / (1 - theta[k]) - 1 /log(theta[k])
        for k in range(K):
            for t in range(len(cov_counts)):
                w_next[k] += ti[k][t] * cov_counts[t]



        theta_next = np.zeros(K)
        lambda_next = np.zeros(K)
        r_next = np.zeros(K)

        ti_q = ti[1:K]
        delta_q = delta[1:K]
        beta_q = beta[1:K]

        # Update component weights
        w_next /= sum(w_next)

        # Keep previous estmation if weight of the error component is less than 1e-6
        if w_next[0] <= 1e-6:
            w_next[0] = 1e-6
            theta_next[0] = theta[0]
            r_next[0] = r[0]

        if w_next[0] > 1e-6:
            # Update parameters for the first component (usually captures errors)
            lambda_next[0] = np.sum(ti[0] * delta[0] * cov_counts) / np.sum(ti[0] * cov_counts)
            #print(beta[0], ti[0], delta[0])
            theta_next[0] = beta[0] * np.sum(ti[0] * delta[0] * cov_counts) / np.sum(ti[0] * cov_counts * (np.arange(len(cov_counts)) + delta[0] * (beta[0] - 1)))
            if (theta_next[0] > 1 - 1e-10): #and (theta_next[0] > 1):
                theta_next[0] = 1 - 1e-10
            r_next[0] = -1 * lambda_next[0] / np.log(theta_next[0])
        #print(theta_next[0])#, ti[0], delta[0], beta[0] )

        # Update parameters for the remaining components (for capturing false duplication, correctly assembled and collapsed regions)
        lambda_next[1:K] = np.sum(ti_q * delta_q * cov_counts) / np.sum(ti_q * cov_counts * ploidies.reshape(K-1, 1)) * ploidies
        theta_next[1:K] = np.sum(ti_q * delta_q * cov_counts * beta_q.reshape(K-1, 1)) / np.sum(ti_q * cov_counts * (np.arange(len(cov_counts)) + delta_q * (beta_q.reshape(K-1,1) - 1)))
        r_next[1:K] = -1 * lambda_next[1:K] / np.log(theta_next[1:K])

        #print(theta_next, r_next, self.n_comps)
        # get updated mean and var
        # make a new object for updated parameters
        cov_err, var_err = CoverageDistribution.get_mean_and_var(theta_next[0], r_next[0])
        mean_1, var_1 = CoverageDistribution.get_mean_and_var(theta_next[1], r_next[1])
        cov_per_ploidy = mean_1 / ploidies[0]
        var_per_ploidy = var_1 / ploidies[0]
        return CoverageDistribution(cov_err, var_err, cov_per_ploidy, var_per_ploidy, K, w_next, ploidies)



    # Fit and return a CoverageDistribution using a coverage frequency histogram,
    # which should be provided as either a collections.Counter or dict whose keys
    # are coverages and whose values are the number of bases that have that
    # coverage        
    @staticmethod
    def fit(spectrum, contain_dup = True, cov = None, tol = 1e-5, init_iters = 10, max_acc_iters = 1000, max_post_iters=1000, combined_rate_change = 0.5):
        
        # copy the spectrum so we can modify it without breaking expectations
        spectrum = deepcopy(spectrum)
        
        # get a coarse estimate that we can use as a starting point for EM
        cov_dist = CoverageDistribution.heuristic_fit(spectrum, contain_dup, cov)
        print ("Initial:")
        print(cov_dist)
        
        init_iters_actual = 0
        acc_iters_actual = 0
        post_iters_actual = 0

        ll = []
        i = 0
        while i < init_iters:
            cov_dist_new = cov_dist.update_params_em(spectrum)
            ll.append(cov_dist_new.loglikelihood(spectrum))
            print(f"iteration {i} (initial)")
            print(cov_dist_new)
            cov_dist = cov_dist_new
            i += 1
        init_iters_actual = i


        i = 0
        while i < max_acc_iters:
            cov_dist_0 = cov_dist
            #print("first")
            cov_dist_1 = cov_dist_0.update_params_em(spectrum)
            #print("second")
            cov_dist_2 = cov_dist_1.update_params_em(spectrum)
            rates = cov_dist_0.computeRates(cov_dist_1, cov_dist_2)

            cov_dist_prime = cov_dist_0.update_params_squarem(rates)
            while not cov_dist_prime.isFeasible() or cov_dist_0.loglikelihood(spectrum) > cov_dist_prime.loglikelihood(spectrum):
                #print(cov_dist_2)
                if rates['alpha'] > (-1 - 1e-10):
                    cov_dist_prime = cov_dist_0
                    break
                #print(cov_dist_prime.isFeasible(), rates['alpha'])
                rates['alpha'] = (rates['alpha'] - 1) / 2
                cov_dist_prime = cov_dist_0.update_params_squarem(rates)
            cov_dist_new = cov_dist_prime.update_params_em(spectrum)
            ll.append(cov_dist_new.loglikelihood(spectrum))
            print(f"iteration {i} (accerelation)")
            print(cov_dist_new)
            if cov_dist.converged(cov_dist_new, tol):
                cov_dist = cov_dist_new
                i += 1
                break
            if i > 10:
                 recent_change = ll[len(ll) - 1] - ll[len(ll) - 6]
                 old_change = ll[len(ll) - 6] - ll[len(ll) - 11] 
                 if (recent_change / old_change) < combined_rate_change:
                     cov_dist = cov_dist_new
                     i += 1
                     break
            cov_dist = cov_dist_new
            i += 1
        acc_iters_actual = i        

        i = 0
        while i < max_post_iters:
            cov_dist_new = cov_dist.update_params_em(spectrum)
            ll.append(cov_dist_new.loglikelihood(spectrum))
            print(f"iteration {i} (after acceleration)")
            print(cov_dist_new)
            if cov_dist.converged(cov_dist_new, tol):
                cov_dist = cov_dist_new
                i += 1
                break
            cov_dist = cov_dist_new
            i += 1
        post_iters_actual = i

        return cov_dist, np.array(ll), init_iters_actual, acc_iters_actual, post_iters_actual 

    @staticmethod
    def heuristic_fit(cov_counts, contain_dup = True, cov=None):
       
        var_factor = 1.1 # this is the ratio of (variance / mean) for each NB component, which should be greater than 1.0
        max_cov = len(cov_counts)
        
        if cov == None:
            cov_at_max_frequency = None
            max_frequency = -1
            has_increased = False
            for cov_value in np.arange(1, len(cov_counts) - 1):
                if cov_counts[cov_value] > cov_counts[cov_value - 1]:
                    has_increased = True
                if cov_counts[cov_value] > max_frequency and has_increased:
                    max_frequency = cov_counts[cov_value]
                    cov_at_max_frequency = cov_value
        
            cov_per_ploidy = cov_at_max_frequency
            var_per_ploidy = cov_per_ploidy * var_factor
        else:
            cov_per_ploidy = cov
            var_per_ploidy = cov * var_factor
        
        max_ploidy = int(max_cov / cov_per_ploidy)
        while max_ploidy * cov_per_ploidy - 3.0 * sqrt(max_ploidy * cov_per_ploidy) < max_cov:
            max_ploidy += 1
            
        # start with an arbitrary, small mean
        cov_err = 1.0
        var_err = cov_err * var_factor
        
        # the denominator along with pseudocounts
        total_count = sum(cov_counts[cov] for cov in np.arange(len(cov_counts))) + max_ploidy + 1
        
        # approximate the mixture weights simplistically by hard-assigning all data
        # points to their nearest component
        mixture_weights = []

        # if false duplication should be modeled then:
        #          - The 0-th comp takes the interval (Err)          :  [0, 0.25 * cov_per_ploidy)
        #          - The 1-th comp takes the interval (Dup)          :  (0.25 * cov_per_ploidy, 0.75 * cov_per_ploidy]
        #          - The 2-nd comp takes the interval  (Hap)         :  (0.75 * cov_per_ploidy, 1.5 * cov_per_ploidy]
        #          - The i-th comp (i >= 3) takes the interval (Col) :  ((i - 1.5) * cov_per_ploidy, (i - 0.5) * cov_per_ploidy]
        # otherwise:
        #          - The 0-th comp takes the interval  (Err)                                            : [0, 0.5 * cov_per_ploidy)
        #          - The i-th (i >= 1) comp takes the interval (i > 1) takes the interval (Hap and Col) : ((i - 0.5) * cov_per_ploidy, (i + 0.5) * cov_per_ploidy]
        if contain_dup: # duplication is modeled
             n_comps = max_ploidy + 2 # (+1 for the duplication comp, +1 for the error comp)
             intervals = [(0, 0.25), (0.25, 0.75), (0.75, 1.5)]
             ploidies = [0.5, 1] # The error component does not have ploidy
             for i in range(3, n_comps):
                 intervals.append((i - 1.5, i - 0.5))
                 ploidies.append(i - 1)
        else: # duplication is not modeled
             n_comps = max_ploidy + 1 # (+1 for the error comp)
             intervals = [(0,0.5)]
             ploidies = []
             for i in range(1, n_comps):
                 intervals.append((i - 0.5, i + 0.5))
                 ploidies.append(i)

        for i1 , i2 in intervals:
            count = 1
            for cov in range(int(i1 * cov_per_ploidy), int(i2 * cov_per_ploidy)):
                if cov >= len(cov_counts):
                    break
                count += cov_counts[cov]
            mixture_weights.append(count / total_count)

        ploidies = np.array(ploidies)
        mixture_weights = np.array(mixture_weights)
        return CoverageDistribution(cov_err, var_err, cov_per_ploidy, var_per_ploidy, n_comps, mixture_weights, ploidies)


    """
        A method for reparameterizing negative binomial distribution from mean and variance to theta and r       
    """
    @staticmethod
    def get_theta_and_r(mean, var):
        theta = mean / var
        r = mean ** 2 / (var - mean)
        return theta, r


    """
        A method for reparameterizing negative binomial distribution from theta and r to mean and variance       
    """
    @staticmethod
    def get_mean_and_var(theta, r):
        mean = r * (1 - theta) / theta
        var = r * (1 - theta) / theta ** 2
        return mean, var


    """
       A method for getting a list of theta and r values for all components
    """
    def get_theta_and_r_values(self):

       theta_0, r_0 = CoverageDistribution.get_theta_and_r(self.cov_per_ploidy, self.var_per_ploidy)
       theta_err, r_err = CoverageDistribution.get_theta_and_r(self.cov_err, self.var_err)

       r = np.zeros(self.n_comps)
       theta = np.zeros(self.n_comps)
       
       theta[0] = theta_err
       r[0] = r_err
       theta[1:self.n_comps] = theta_0 
       r[1:self.n_comps] = self.ploidies * r_0

       return theta, r



    def loglikelihood(self, cov_counts):
        p = np.zeros((self.n_comps, sum(cov_counts>0)))
        theta, r = self.get_theta_and_r_values()
        
        covs_non_zero = np.arange(len(cov_counts))[cov_counts > 0]
        counts_non_zero = cov_counts[cov_counts > 0]

        for k in range(self.n_comps):
            p[k] = nbinom.pmf(covs_non_zero, r[k], theta[k]) * self.get_weights()[k]
        p_sum = np.sum(p,axis=0)
        return np.sum(np.log(p_sum) * counts_non_zero)


    def computeRates(self, cov_dist_1, cov_dist_2):
        theta_0, r_0 = self.get_theta_and_r_values()
        theta_1, r_1 = cov_dist_1.get_theta_and_r_values()
        theta_2, r_2 = cov_dist_2.get_theta_and_r_values()
        w_0 = self.get_weights()
        w_1 = cov_dist_1.get_weights()
        w_2 = cov_dist_2.get_weights()
        
        r_r = r_1 - r_0
        v_r = r_2 - r_1 - r_r

        r_theta = theta_1 - theta_0
        v_theta = theta_2 - theta_1 - r_theta

        r_weight = w_1 - w_0
        v_weight = w_2 - w_1 - r_weight


        rates = {'alpha': 0, 'theta': {}, 'r': {}, 'weight': {}}
        rates['alpha'] = -1 * np.sqrt(np.sum(r_r ** 2 + r_theta ** 2 + r_weight ** 2) / np.sum(v_r ** 2 + v_theta ** 2 + v_weight ** 2))
        rates['theta']['v'] = v_theta
        rates['theta']['r'] = r_theta
        rates['r']['v'] = v_r
        rates['r']['r'] = r_r
        rates['weight']['v'] = v_weight
        rates['weight']['r'] = r_weight

        return rates


    def update_params_squarem(self, rates):
        theta_next, r_next = self.get_theta_and_r_values()
        w_next = self.get_weights()

        theta_next += -2 * rates['alpha'] * rates['theta']['r'] + rates['alpha'] ** 2 * rates['theta']['v']
        r_next += -2 * rates['alpha'] * rates['r']['r'] + rates['alpha'] ** 2 * rates['r']['v']
        w_next += -2 * rates['alpha'] * rates['weight']['r'] + rates['alpha'] ** 2 * rates['weight']['v']

        # get updated mean and var
        # make a new object for updated parameters
        cov_err, var_err = CoverageDistribution.get_mean_and_var(theta_next[0], r_next[0])
        mean_1, var_1 = CoverageDistribution.get_mean_and_var(theta_next[1], r_next[1])
        cov_per_ploidy = mean_1 / self.ploidies[0]
        var_per_ploidy = var_1 / self.ploidies[0]
        return CoverageDistribution(cov_err, var_err, cov_per_ploidy, var_per_ploidy, self.n_comps, w_next, self.ploidies)


    def isFeasible(self):
        theta, r = self.get_theta_and_r_values()
        w = self.get_weights()

        theta_pos = np.all(theta > 0)
        theta_lt_1 = np.all(theta < 1)
        r_pos = np.all(r > 0)
        weight_pos = np.all(w > 0)
        weight_lt_1 = np.all(w < 1)

        return theta_pos and theta_lt_1 and r_pos and weight_pos and weight_lt_1

    def get_weights(self):
        return deepcopy(self.mixture_weights)

    def converged(self, other, tol):
        if other.cov_per_ploidy != 0.0 and abs(self.cov_per_ploidy / other.cov_per_ploidy - 1.0) > tol:
            return False
        if other.var_per_ploidy != 0.0 and abs(sqrt(self.var_per_ploidy / other.var_per_ploidy) - 1.0) > tol:
            return False
        if other.cov_err != 0.0 and abs(self.cov_err / other.cov_err - 1.0) > tol:
            return False
        if other.var_err != 0.0 and abs(self.var_err / other.var_err - 1.0) > tol:
            return False
        max_ploidy_of_interest = 4
        for wt, other_wt, ploidy in zip(self.mixture_weights, other.mixture_weights, self.ploidies):
            if ploidy > max_ploidy_of_interest:
                break
            if other_wt and abs(wt / other_wt - 1.0) > tol:
                return False
        return True


def simulate_data(max_ploidy, contain_dup, params_str, weights_str, n_obs):
    """Simulate a coverage histogram from a negative binomial mixture.

    Arguments:
        - max_ploidy: largest ploidy to simulate; components are created for
          ploidies 1..max_ploidy (plus 0.5 if contain_dup, plus the error one)
        - contain_dup: whether a false-duplication component is simulated
        - params_str: "cov_err,var_err,cov_per_ploidy,var_per_ploidy"
        - weights_str: comma-separated mixture weights. 5 values are required
          if contain_dup is True and 4 otherwise. All but the last are used
          as-is for the first components; the last one is split equally
          among the remaining components.
        - n_obs: total number of observations; component k contributes
          int(n_obs * weight_k) of them
    Returns:
        - (cov_counts, model_truth): the histogram of simulated values as a
          float array, and the CoverageDistribution used to generate them
    Raises:
        - AssertionError if a variance is not greater than its mean or if the
          number of weights does not match `contain_dup`
        - ValueError if no observation was generated
    """
    # extract parameters of components
    # variance should be greater than mean for NB distributions
    cov_err, var_err, cov_per_ploidy, var_per_ploidy = [float(w) for w in params_str.split(",")]
    assert cov_err < var_err, "var_err must be greater than cov_err"
    assert cov_per_ploidy < var_per_ploidy, "var_per_ploidy must be greater than cov_per_ploidy"

    # make a list of mean and variance values for all components of the generator model
    mean = [cov_err]
    ploidies = []
    if contain_dup:
        mean.append(cov_per_ploidy * 0.5)
        ploidies.append(0.5)

    n_comps = max_ploidy + (2 if contain_dup else 1)
    for i in range(1, max_ploidy + 1): # add components other than error and false duplication
        mean.append(i * cov_per_ploidy)
        ploidies.append(i)
    mean = np.array(mean)
    var_factor = var_per_ploidy / cov_per_ploidy # ratio var/mean
    var = mean * var_factor
    # The error component has its own variance; it is not tied to var_factor.
    # This keeps the simulated data consistent with the returned truth model.
    var[0] = var_err
    ploidies = np.array(ploidies)

    # make a list of mixture weights for all components of the generator model
    weights_float = [float(w) for w in weights_str.split(",")]
    # The expected length is computed first: written inline, the conditional
    # expression would swallow the comparison and the check would never fail
    # when contain_dup is False.
    expected_n_weights = 5 if contain_dup else 4
    assert len(weights_float) == expected_n_weights, f"expected {expected_n_weights} weights but got {len(weights_float)}"

    mixture_weights = weights_float[:-1]
    col_rest_comps = n_comps - len(mixture_weights)
    # set the weight of the remaining collapsed components all equal
    for i in range(col_rest_comps):
        mixture_weights.append(weights_float[-1] / (col_rest_comps))

    print(max_ploidy, n_comps, mixture_weights, mean)
    # convert mean and variance to r and theta to be able to pass it to nbinom.rvs()
    theta = mean / var
    r = mean ** 2 / (var - mean)

    # generate observations from all components
    # the frequency of generating by each component is
    # determined by the given mixture weights
    obs = []
    for k in range(n_comps):
        obs.extend(nbinom.rvs(r[k], theta[k], size=int(n_obs * mixture_weights[k])))
    if not obs:
        raise ValueError("no observations were generated; increase n_obs")

    # count the number of occurences of each observed value
    cov_counts = np.zeros(max(obs) + 1, dtype=float)
    for x_t in obs:
        cov_counts[x_t] += 1

    # make a truth model for downstream benchmarking
    # (weights are stored as an array, like the models produced by fitting)
    model_truth = CoverageDistribution(cov_err, var_err, cov_per_ploidy, var_per_ploidy, n_comps, np.array(mixture_weights), ploidies)

    return cov_counts, model_truth


def parse_data(path):
    """Read a whitespace-delimited coverage/count table into a dense array.

    The first column of each line is an integer coverage and the second one
    its count. Blank lines and lines whose first field starts with '#' are
    skipped. If a coverage appears more than once the last count wins.

    Arguments:
        - path: path of the input file
    Returns:
        - A float numpy array of length (max coverage + 1) where entry t is
          the count of coverage t (0 for coverages missing from the file)
    Raises:
        - ValueError if the file contains no data lines
    """
    counts = []
    covs = []
    with open(path, "r") as f:
        for line in f:
            cols = line.split()
            # tolerate empty lines and header/comment lines
            if not cols or cols[0].startswith("#"):
                continue
            covs.append(int(cols[0]))
            counts.append(float(cols[1]))

    if not covs:
        raise ValueError(f"no coverage counts found in {path}")

    cov_counts = np.zeros(max(covs) + 1, dtype=float)
    for cov, count in zip(covs, counts):
        cov_counts[cov] = count

    return cov_counts


"""
    Plot empirical frequencies, fit and truth probabilities
    It will not draw truth probabilites if model_truth is None (when observations are from real data)
"""
def plot_dist(cov_counts, model_fit, model_truth, figure_x_max, out_dir, info):
    """Plot observed frequencies together with the fit (and truth) mixture.

    Arguments:
        - cov_counts: array where entry t is the count of coverage t
        - model_fit: the fit model
        - model_truth: the generating model, or None when the counts come
          from real data (the truth curve is then omitted)
        - figure_x_max: upper limit of the x axis; if None it is set to
          4 * model_fit.cov_per_ploidy
        - out_dir: output directory (created if needed); the figure is saved
          as <out_dir>/<basename of out_dir>.fit_model.pdf
        - info: text shown in the title when no truth model is given
    """
    n_obs = sum(cov_counts)
    x = np.arange(len(cov_counts))

    # Mixture probabilities of the fit model
    theta_fit, r_fit = model_fit.get_theta_and_r_values()
    probs_fit = np.zeros(len(x))
    for k in range(model_fit.n_comps):
        probs_fit += model_fit.mixture_weights[k] * nbinom.pmf(x, r_fit[k], theta_fit[k])

    # The truth model is known if observations are generated by simulation
    if model_truth:
        theta_truth, r_truth = model_truth.get_theta_and_r_values()
        probs_truth = np.zeros(len(x))
        for k in range(model_truth.n_comps):
            probs_truth += model_truth.mixture_weights[k] * nbinom.pmf(x, r_truth[k], theta_truth[k])

    # Plot figure
    plt.figure(figsize=(20,12))
    ax = plt.axes([0.1,0.1,0.8,0.8])

    ax.plot(x, cov_counts/n_obs, label = "Observation")
    # Plot the truth probabilites if truth model is provided
    if model_truth:
        ax.plot(x, probs_truth, label="Truth")
    # Plot fit model
    ax.plot(x, probs_fit, label="Fit")

    ax.set_xlabel("Coverage")
    # Probabilities are on the y axis; labelling the x axis a second time would
    # overwrite "Coverage" and leave the y axis without a label.
    ax.set_ylabel("Probability")
    if figure_x_max is None:
        figure_x_max = 4 * model_fit.cov_per_ploidy
    ax.set_xlim(0,figure_x_max)

    # Write title
    if model_truth:
        ax.set_title(f"Fitting a Negative Binomial Mixture Model (NBMM) to counts data \n Counts data is produced by simulation (N={n_obs})")
    else:
        ax.set_title(f"Fitting a Negative Binomial Mixture Model (NBMM) to counts data \n {info}")
    ax.legend()

    # Save and show figure
    os.makedirs(f"{out_dir}", exist_ok=True)
    prefix = os.path.basename(out_dir)
    plt.savefig(f"{out_dir}/{prefix}.fit_model.pdf")
    plt.show()


def plot_comps(cov_counts, model_fit, contain_dup, simulate, figure_x_max, out_dir, info):
    """Plot observed frequencies, the fit mixture and its individual components.

    The curves are: observation, full fit, error, (duplicated if contain_dup),
    haploid, and the sum of all collapsed components.

    Arguments:
        - cov_counts: array where entry t is the count of coverage t
        - model_fit: the fit model
        - contain_dup: whether component 1 of the model is the
          false-duplication component
        - simulate: whether the counts were simulated (changes the title)
        - figure_x_max: upper limit of the x axis; if None it is set to
          4 * model_fit.cov_per_ploidy
        - out_dir: output directory (created if needed); the figure is saved
          as <out_dir>/<basename of out_dir>.fit_model.components.pdf
        - info: text shown in the title when the data is not simulated
    """
    n_obs = sum(cov_counts)
    x = np.arange(len(cov_counts))

    # One row of weighted probabilities per component
    theta_fit, r_fit = model_fit.get_theta_and_r_values()
    probs_fit = np.zeros((model_fit.n_comps, len(x)))
    for k in range(model_fit.n_comps):
        probs_fit[k] = model_fit.mixture_weights[k] * nbinom.pmf(x, r_fit[k], theta_fit[k])

    # Plot figure
    plt.figure(figsize=(20,12))
    ax = plt.axes([0.1,0.1,0.8,0.8])

    ax.plot(x, cov_counts/n_obs, label = "Observation", color="black")
    ax.plot(x, sum(probs_fit), label="Fit", color="blue")
    ax.plot(x, probs_fit[0], label="Err", color="red")
    if contain_dup:
        ax.plot(x, probs_fit[1], label="Dup", color="orange")
        ax.plot(x, probs_fit[2], label="Hap", color="green")
        ax.plot(x, sum(probs_fit[3:]), label="Col", color="purple")
    else:
        ax.plot(x, probs_fit[1], label="Hap", color="green")
        ax.plot(x, sum(probs_fit[2:]), label="Col", color="purple")

    ax.set_xlabel("Coverage")
    # Probabilities are on the y axis; labelling the x axis a second time would
    # overwrite "Coverage" and leave the y axis without a label.
    ax.set_ylabel("Probability")
    if figure_x_max is None:
        figure_x_max = 4 * model_fit.cov_per_ploidy
    ax.set_xlim(0,figure_x_max)

    # Write title
    if simulate:
        ax.set_title(f"Fitting a Negative Binomial Mixture Model (NBMM) to counts data \n Counts data is produced by simulation (N={n_obs})")
    else:
        ax.set_title(f"Fitting a Negative Binomial Mixture Model (NBMM) to counts data \n {info}")
    ax.legend()

    # Save and show figure
    os.makedirs(f"{out_dir}", exist_ok=True)
    prefix = os.path.basename(out_dir)
    plt.savefig(f"{out_dir}/{prefix}.fit_model.components.pdf")
    plt.show()

"""
    Plot log-likelihood w.r.t to effective EM iterations
    Note that each iteration in accelerated mode is considered 
    three effective iterations since we update the 
    parameters three times; two times for computing the rate of change (alpha)
    and once for computing the final update of that iteration
    (Look at CoverageDistribution.fit())
"""
def plot_log_likelihood(ll, init_iters_actual, acc_iters_actual, post_iters_actual, simulate, n_obs, out_dir, info):
    """Plot the log-likelihood against the effective EM iteration number.

    Points are colored by phase: black for the initial iterations, blue for
    the accelerated ones and red for the post-acceleration ones. Each
    accelerated iteration advances the x axis by three.

    The figure is saved as <out_dir>/<basename of out_dir>.loglikelihood.pdf
    (the directory is created if needed) and then shown.
    """
    phase_colors = {"init": "black", "acc": "blue", "post": "red"}

    # x positions of the points of each phase
    init_axis = np.arange(init_iters_actual)
    acc_axis = np.arange(acc_iters_actual) * 3 + init_iters_actual + 2
    post_axis = np.arange(post_iters_actual) + init_iters_actual + acc_iters_actual * 3
    ll_effective_iters = np.concatenate([init_axis, acc_axis, post_axis])

    # one color per point, in the same order as the x positions
    ll_colors = np.concatenate([[phase_colors["init"]] * init_iters_actual,
                                [phase_colors["acc"]] * acc_iters_actual,
                                [phase_colors["post"]] * post_iters_actual])

    assert(len(ll_effective_iters) == len(ll_colors))

    # Plot figure
    plt.figure(figsize=(20,12))
    ax = plt.axes([0.1,0.1,0.8,0.8])
    ax.set_xlabel("Effective EM iterations")
    ax.set_ylabel("Log Likelihood")

    ax.plot(ll_effective_iters, ll, color="gray", label = "log-likelihood", linewidth=1)
    ax.scatter(ll_effective_iters, ll, color=ll_colors, s=20)

    # Write title
    title_head = "Log-likelihood w.r.t to effective EM iterations \n "
    if simulate:
        ax.set_title(title_head + f"Counts data is produced by simulation (N={n_obs})")
    else:
        ax.set_title(title_head + f"{info}")

    # One legend entry per phase
    legend_spec = [('black', 'Initial iterations (pre-acceleration)'),
                   ('blue', 'Accelerated iterations'),
                   ('red', 'Post-acceleration iterations')]
    legend_elements = [Line2D([0], [0], marker='o', color=shade, label=text, markerfacecolor=shade, markersize=5)
                       for shade, text in legend_spec]
    ax.legend(handles=legend_elements, loc='lower right')

    # Save and show figure
    os.makedirs(f"{out_dir}", exist_ok=True)
    prefix = os.path.basename(out_dir)
    plt.savefig(f"{out_dir}/{prefix}.loglikelihood.pdf")
    plt.show()


def write_probabilities(cov_counts, model_fit, simulate, contain_dup, out_dir):
    """Write the per-coverage component probabilities (and simulated counts) to files.

    Files are written into `out_dir` (created if needed) and are prefixed
    with the basename of `out_dir`:
        - <prefix>.counts: only if `simulate` is True; one "index<TAB>count"
          line per entry of cov_counts
        - <prefix>.coverage_component_probs.tsv: for each coverage the
          observed count, the count expected under the fit model, and the
          probability of the coverage belonging to the error, (duplicated,
          only if `contain_dup`), haploid and collapsed components

    Arguments:
        - cov_counts: array where entry t is the count of coverage t
        - model_fit: the fit model
        - simulate: whether the counts were simulated
        - contain_dup: whether the duplicated column is written
        - out_dir: output directory
    """
    print(f"Writing output files (output dir = {out_dir}) ...")

    os.makedirs(f"{out_dir}", exist_ok=True)
    prefix = os.path.basename(out_dir)

    # get r, theta from the fit model
    theta_fit, r_fit = model_fit.get_theta_and_r_values()

    # get fit probabilities
    probs_fit = np.zeros(len(cov_counts))
    x = np.arange(len(cov_counts))
    for k in range(model_fit.n_comps):
        probs_fit += model_fit.mixture_weights[k] * nbinom.pmf(x, r_fit[k], theta_fit[k])

    # The total is the same for every row, so it is computed once instead of
    # once per output line.
    total_count = sum(cov_counts)

    if simulate:
        with open(f"{out_dir}/{prefix}.counts","w+") as f:
            for i,cov in enumerate(cov_counts):
                f.write(f"{i}\t{cov}\n")

    # Write the observed counts and also the counts fit (could be explained) by the model
    # Write the probability of each coverage being generated by each component
    # False duplication probabilities are written in the output file only if it is modeled
    if contain_dup:
        header = "#coverage\tfreq\tfit\terror\tduplicated\thaploid\tcollapsed\n"
        prob_getters = [model_fit.probability_erroneous,
                        model_fit.probability_duplicated,
                        model_fit.probability_haploid,
                        model_fit.probability_collapsed]
    else:
        header = "#coverage\tfreq\tfit\terror\thaploid\tcollapsed\n"
        prob_getters = [model_fit.probability_erroneous,
                        model_fit.probability_haploid,
                        model_fit.probability_collapsed]

    with open(f"{out_dir}/{prefix}.coverage_component_probs.tsv","w+") as f:
        f.write(header)
        for cov in range(len(cov_counts)):
            fields = ["{:d}".format(cov),
                      "{:f}".format(cov_counts[cov]),
                      "{:.4f}".format(probs_fit[cov] * total_count)]
            fields.extend("{:.4f}".format(get_prob(cov)) for get_prob in prob_getters)
            f.write("\t".join(fields) + "\n")



"""
    Write a table containing loglikelihood, AIC and BIC values w.r.t effective iterations
"""
def write_loglikelihood(ll, aic, bic, init_iters_actual, acc_iters_actual, post_iters_actual, out_dir):
    """
        Write "<out_dir>/<basename of out_dir>.loglikelihood.tsv", a tab-delimited table
        with one row per recorded value: effective iteration, log-likelihood, AIC and BIC.

        The effective iteration axis is assembled from three consecutive segments:
            - initial iterations:      0, 1, ..., init_iters_actual - 1
            - accelerated iterations:  spaced 3 apart, starting at init_iters_actual + 2
            - post iterations:         spaced 1 apart, starting at init_iters_actual + 3 * acc_iters_actual

        Args:
            - ll, aic, bic:        Sequences iterated in lockstep with the effective iterations
                                   (zip stops at the shortest one)
            - init_iters_actual:   Number of entries in the initial segment
            - acc_iters_actual:    Number of entries in the accelerated segment
            - post_iters_actual:   Number of entries in the post segment
            - out_dir:             Output directory (created if it does not exist)
    """
    init_segment = np.arange(init_iters_actual)
    acc_segment = np.arange(acc_iters_actual) * 3 + init_iters_actual + 2
    post_segment = np.arange(post_iters_actual) + init_iters_actual + acc_iters_actual * 3
    effective_iters = np.concatenate([init_segment, acc_segment, post_segment])

    os.makedirs(out_dir, exist_ok=True)
    table_path = f"{out_dir}/{os.path.basename(out_dir)}.loglikelihood.tsv"

    with open(table_path, "w+") as table:
        table.write("#effective_iteration\tloglikelihood\tAIC\tBIC\n")
        table.writelines(
            f"{iteration}\t{ll_value}\t{aic_value}\t{bic_value}\n"
            for iteration, ll_value, aic_value, bic_value in zip(effective_iters, ll, aic, bic)
        )


def main(argv=None):
    """
        Command-line entry point.

        Parses the options, obtains the coverage counts (simulated, or read from --counts),
        fits the negative binomial mixture model with CoverageDistribution.fit, prints the
        final log-likelihood / BIC / AIC and writes the plots and tables into --out-dir.

        Args:
            - argv:	Optional list of command-line tokens (without the program name).
                   	When None, argparse falls back to sys.argv[1:].

        Exits with a usage error (status 2) when --em-mode is not one of the three
        supported modes or when neither --counts nor --simulate is given.
    """
    parser = argparse.ArgumentParser(description='A script for fitting a negative binomial mixture model (NBMM) to an observed counts data (or simluated data)')
    # `choices` makes argparse reject an unknown mode up front, before any data is loaded
    parser.add_argument('--em-mode', type=str, default="combined", choices=("regular", "accelerated", "combined"),
                    help='There are three modes for the EM algorithm; regular, accelerated, combined (default: combined)')
    parser.add_argument('--cov', type=float, default = None,
                    help='coverage per ploidy (By Default the program determines the coverage per ploidy without this parameter however if it is provided the EM algorithm start points will be adjusted based on this)')
    parser.add_argument('--include-dup-comp', action='store_true', default=False,
                    help='The model will take into account a component for modeling false duplication.')
    parser.add_argument('--counts', type=str,
                    help='The input tab-delimited file that contains observed counts. The first and second columns should contain observation (e.g. coverage) and related count respectively.')
    parser.add_argument('--info', type=str, default="",
                    help='Info string to be shown in the title of the plots [Default = ""]')
    parser.add_argument('--out-dir', type=str, default = "nbmm_test",
                    help='The output dir')
    parser.add_argument('--figure-x-max', type=int,
                    help='The maximum x value to be shown on the output distributin figure.')
    parser.add_argument('--init-iters', type=int, default = 10,
                    help='The number of initial iterations for EM algorithm prior to starting the accelerated mode. [Default = 10]')
    parser.add_argument('--max-acc-iters', type=int, default = 1000,
                    help='The maximum number of iterations for EM algorithm in the accelerated mode (in any iteration if the parameters changed by less than the given tolerance --tol then it will break in that iteration). This parameter will not be used if --em-mode is "regular" which forces the program not to use the accelerated mode. [Default = 1000]')
    parser.add_argument('--max-post-iters', type=int, default = 1000,
                    help='The maximum number of iterations for EM algorithm after the accelerated mode (in any iteration if the parameters changed by lower than the given tolerance --tol then it will break in that iteration). [Default = 1000]')
    parser.add_argument('--tol', type=float, default = 1e-5,
                    help='In any EM iteration if all parameters changed by less than this tolerance then EM will be terminated in that iteration. In otherwords if abs(param_1 - param_0) / param_0 < tolerance for all parameters then EM will be terminated. [Default = 1e-5]')
    parser.add_argument('--combined-rate-change', type=float, default = 0.5,
                    help='Related to "combined" mode. During accelerated iterations if the rate of increasing log-likelihood decreased by this amount the accelerated mode will stop there and the following regular mode will start. For calculating this rate change (an approximation of 2nd derivative) the log-likelihood values of 10 earlier iterations will be considered. It has been observed that the "accelerated" mode is faster in initial iterations but after a while the "regular" mode will be faster. [Default = 0.5]')
    parser.add_argument('--simulate', action='store_true', default=False,
                    help='Use simulated data instead of real data (--counts will be ignored)')
    parser.add_argument('--simulate-max-ploidy', type=int, default=20,
                    help='Maximum ploidy in simulated data (should be greater than 1)')
    parser.add_argument('--simulate-include-dup-comp', action='store_true', default=False,
                    help='The simulation will generate observations for false duplication. The frequency of each component is determined based on the given weights --simulate-weights)')
    parser.add_argument('--simulate-params', type=str, default= "1,2,30,40",
                    help='A comma-separated string that contains the parameters of model for generating the simulated data; ${mean-err},${var-err},${mean-hap},${var-hap}. [Default = "1,2,30,40"]')
    parser.add_argument('--simulate-weights', type=str, default= "0.05,0.05,0.7,0.1,0.1",
                    help='A comma-separated string that contains the weights of the generator components for the simulated data; ${err-weight},${dup-weight},${hap-weight},${col-weight-first},${col-weight-rest}. If --simulate-include-dup-comp is turned off then ${dup-weight} should be removed from the string. [Default = "0.05,0.05,0.7,0.1,0.1"]')
    parser.add_argument('--simulate-n-obs', type=int, default= 10000,
                    help='Number of simulated observations [Default = 10000]')
    parser.add_argument('--disable-plot', action='store_true',
                    help='Do not output any plot')

    args = parser.parse_args(argv)

    # Without --simulate the counts file is the only data source, so fail early
    # with a usage message instead of handing None to the parser of the counts file
    if not args.simulate and args.counts is None:
        parser.error("--counts is required unless --simulate is given")

    em_mode = args.em_mode
    contain_dup = args.include_dup_comp
    simulate = args.simulate
    out_dir = args.out_dir
    figure_x_max = args.figure_x_max
    info = args.info

    # A mode is enforced by giving zero iterations to the phase it excludes
    max_acc_iters = 0 if em_mode == "regular" else args.max_acc_iters
    max_post_iters = 0 if em_mode == "accelerated" else args.max_post_iters


    #####################################
    # Parse or generate data / Fit NBMM #
    #####################################

    model_truth = None
    if simulate:
        cov_counts, model_truth = simulate_data(args.simulate_max_ploidy,
                                                args.simulate_include_dup_comp,
                                                args.simulate_params,
                                                args.simulate_weights,
                                                args.simulate_n_obs)
    else:
        cov_counts = parse_data(args.counts)
    n_obs = sum(cov_counts)

    model_fit, ll, init_iters_actual, acc_iters_actual, post_iters_actual = CoverageDistribution.fit(cov_counts,
                                                                                                     contain_dup,
                                                                                                     cov = args.cov,
                                                                                                     tol = args.tol,
                                                                                                     init_iters=args.init_iters,
                                                                                                     max_acc_iters=max_acc_iters,
                                                                                                     max_post_iters=max_post_iters,
                                                                                                     combined_rate_change=args.combined_rate_change)

    print(f"init_iters_actual = {init_iters_actual}\nacc_iters_actual = {acc_iters_actual}\npost_iters_actual = {post_iters_actual}")

    # n_comps - 1 = number of independent weight parameters
    # 4 = 2 (error component) + 2 (other components)
    n_params = model_fit.n_comps - 1 + 4

    # The criteria below are computed element-wise for every recorded log-likelihood,
    # which needs an ndarray (np.asarray leaves an existing ndarray untouched)
    ll = np.asarray(ll)
    bic = -2 * ll + np.log(n_obs) * n_params
    # NOTE(review): unlike the BIC above, this AIC is divided by n_obs; kept exactly
    # as it was -- confirm that the per-observation normalization is intended
    aic = -2/n_obs * ll + 2 * n_params/n_obs

    print("Final log-likelihood:", ll[-1])
    print("Final BIC:", bic[-1])
    print("Final AIC:", aic[-1])


    #################################
    # Plot figures and write tables #
    #################################

    if not args.disable_plot:
        # plot emprical frequencies beside fit (and truth if present) distributions
        plot_dist(cov_counts, model_fit, model_truth, figure_x_max, out_dir, info)

        plot_comps(cov_counts, model_fit, contain_dup, simulate, figure_x_max, out_dir, info)

        # plot loglikelihood w.r.t effective iterations
        plot_log_likelihood(ll, init_iters_actual, acc_iters_actual, post_iters_actual, simulate, n_obs, out_dir, info)

    write_probabilities(cov_counts, model_fit, simulate, contain_dup, out_dir)
    write_loglikelihood(ll, aic, bic, init_iters_actual, acc_iters_actual, post_iters_actual, out_dir)

 
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module
if __name__ == "__main__":
    main()


